--- external/gsn/models/nerf_utils.py	2023-03-28 16:59:06.690064402 +0000
+++ external_reference/gsn/models/nerf_utils.py	2023-03-28 16:00:21.946531010 +0000
@@ -141,7 +141,7 @@
 
     return pts, viewdirs, z_vals, rd, ro
 
-
+# CHANGED: can return inverse depth (disparity) via use_disp, and gives the final sample on each ray the thickness of the previous one
 def volume_render_radiance_field(
     rgb,
     occupancy,
@@ -151,16 +151,13 @@
     alpha_activation='relu',
     activate_rgb=True,
     density_bias=0,
+    use_disp=True,
 ):
 
-    one_e_10 = torch.tensor([1e10], dtype=ray_directions.dtype, device=ray_directions.device)
-    dists = torch.cat(
-        (
-            depth_values[..., 1:] - depth_values[..., :-1],
-            one_e_10.expand(depth_values[..., :1].shape),
-        ),
-        dim=-1,
-    )
+    # CHANGED: assume the last sample has the same thickness as the previous
+    # sample, rather than 1e10 (which makes it hard for the model to learn transparency)
+    dists = depth_values[..., 1:] - depth_values[..., :-1]
+    dists = torch.cat([dists, dists[..., [-1]]], dim=-1)
     dists = dists * ray_directions[..., None, :].norm(p=2, dim=-1)
 
     noise = 0.0
@@ -184,7 +181,9 @@
     weights = alpha * cumprod_exclusive(1.0 - alpha + 1e-10)
 
     if activate_rgb:
-        rgb = torch.sigmoid(rgb)
+        # CHANGED: use a [-1, 1] activation; tanh(x / 2) == 2 * sigmoid(x) - 1
+        rgb = torch.tanh(rgb/2) # scaled sigmoid [-1, 1]
+        # rgb = torch.sigmoid(rgb)
 
         # widened sigmoid from https://github.com/google/mipnerf/blob/main/internal/models.py#L123
         rgb_padding = 0.001
@@ -193,7 +192,11 @@
     rgb_map = weights[..., None] * rgb
     rgb_map = rgb_map.sum(dim=-2)
 
-    depth_map = weights * depth_values
+    # CHANGED: with use_disp, depth_map holds weighted inverse depth (disparity) instead of depth
+    if use_disp:
+        depth_map = weights * 1/depth_values
+    else:
+        depth_map = weights * depth_values
     depth_map = depth_map.sum(dim=-1)
 
     acc_map = weights.sum(dim=-1)
@@ -205,7 +208,13 @@
         torch.log(0.1 + alpha.view(alpha.size(0), -1)) + torch.log(0.1 + 1.0 - alpha.view(alpha.size(0), -1)) - -2.20727
     )
 
-    return rgb_map, disp_map, acc_map, weights, depth_map, occupancy_prior
+    # used for opacity regularization
+    extras = {
+        'dists': dists,
+        'alpha': alpha,
+    }
+
+    return rgb_map, disp_map, acc_map, weights, depth_map, occupancy_prior, extras
 
 
 def sample_pdf_2(bins, weights, num_samples, det=False):
